{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import numpy as np \n",
    "import seaborn as sns \n",
    "\n",
    "import csop2_helper as csop2\n",
    "import statsmodels.formula.api as smf\n",
    "import re \n",
    "import matplotlib.pyplot as plt \n",
    "from scipy.stats import entropy\n",
    "import scipy\n",
    "from tqdm import tqdm \n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.linear_model import LinearRegression, ElasticNet\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "\n",
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "pd.set_option('max_colwidth', 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.24.2'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sklearn\n",
    "sklearn.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the two input tables used by the summary cells: `df` from a pickle, `df_players` from CSV.\n",
    "# NOTE(review): pd.read_pickle can run arbitrary code while unpickling -- only load trusted files.\n",
    "df = pd.read_pickle(\"phase_2_within_sample_processed.pkl\")\n",
    "df_players = pd.read_csv(\"players.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": "true"
   },
   "source": [
    "# RME summary stats "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "count    1211.000000\n",
       "mean       27.592898\n",
       "std         4.388593\n",
       "min         5.000000\n",
       "25%        25.000000\n",
       "50%        28.000000\n",
       "75%        31.000000\n",
       "max        36.000000\n",
       "Name: cumulativeScore_RME, dtype: float64"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Descriptive statistics of `cumulativeScore_RME`, restricted to rows where `player_valid` is truthy.\n",
    "df_players.query(\"player_valid\")['cumulativeScore_RME'].describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": "true"
   },
   "source": [
    "# Cognitive style test-retest reliability "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "optimization               121\n",
       "constraint-satisfaction    102\n",
       "Name: step1_strategyQ3, dtype: int64"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Counts of each step-1 Q3 answer among rows that have both a step-1 and a step-2 Q1 answer.\n",
    "# NOTE(review): rows are filtered on the Q1 columns but Q3 is tabulated -- confirm this Q1/Q3 mix is intended.\n",
    "df_players.dropna(subset=[\"step1_strategyQ1\", \"step2_strategyQ1\"])['step1_strategyQ3'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7117903930131004"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Overall: share of rows whose Q1 answer is identical at step 1 and step 2,\n",
    "# among rows with a non-missing Q1 answer at both steps.\n",
    "q1_answered_twice = df_players.dropna(subset=[\"step1_strategyQ1\", \"step2_strategyQ1\"])\n",
    "q1_answered_twice[\"step1_strategyQ1\"].eq(q1_answered_twice[\"step2_strategyQ1\"]).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.74235807860262"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Share of rows whose Q2 answer is identical at step 1 and step 2,\n",
    "# among rows with a non-missing Q2 answer at both steps.\n",
    "q2_answered_twice = df_players.dropna(subset=[\"step1_strategyQ2\", \"step2_strategyQ2\"])\n",
    "q2_answered_twice[\"step1_strategyQ2\"].eq(q2_answered_twice[\"step2_strategyQ2\"]).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7399103139013453"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Share of rows whose Q3 answer is identical at step 1 and step 2,\n",
    "# among rows with a non-missing Q3 answer at both steps.\n",
    "q3_answered_twice = df_players.dropna(subset=[\"step1_strategyQ3\", \"step2_strategyQ3\"])\n",
    "q3_answered_twice[\"step1_strategyQ3\"].eq(q3_answered_twice[\"step2_strategyQ3\"]).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7815126050420168"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# High skill (SI): Q3 step-1 vs step-2 agreement, restricted to rows where `high_CSOP` is truthy\n",
    "# (after dropping rows with a missing Q3 answer at either step).\n",
    "q3_high_skill = df_players.dropna(subset=[\"step1_strategyQ3\", \"step2_strategyQ3\"]).query(\"high_CSOP\")\n",
    "q3_high_skill[\"step1_strategyQ3\"].eq(q3_high_skill[\"step2_strategyQ3\"]).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": "true"
   },
   "source": [
    "# Figure S6 summary stats "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "82.43886259323006"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 15th percentile of `normalized_score` over all rows of `df`.\n",
    "df['normalized_score'].quantile(0.15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>normalized_score</th>\n",
       "      <th>round_duration</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>complexity_cat</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Very low</th>\n",
       "      <td>96.008403</td>\n",
       "      <td>2.107228</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Low</th>\n",
       "      <td>93.833588</td>\n",
       "      <td>2.629592</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Moderate</th>\n",
       "      <td>89.587889</td>\n",
       "      <td>3.650680</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>High</th>\n",
       "      <td>89.645056</td>\n",
       "      <td>4.044728</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Very high</th>\n",
       "      <td>86.295181</td>\n",
       "      <td>5.057228</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                normalized_score  round_duration\n",
       "complexity_cat                                  \n",
       "Very low               96.008403        2.107228\n",
       "Low                    93.833588        2.629592\n",
       "Moderate               89.587889        3.650680\n",
       "High                   89.645056        4.044728\n",
       "Very high              86.295181        5.057228"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean `normalized_score` and mean `round_duration` within each `complexity_cat` group.\n",
    "df.groupby(\"complexity_cat\")[['normalized_score', 'round_duration']].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": "true"
   },
   "source": [
    "# Out-of-sample Q2 baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature table used by the out-of-sample baseline cells.\n",
    "# NOTE(review): pd.read_pickle can run arbitrary code while unpickling -- only load trusted files.\n",
    "df_oos = pd.read_pickle(\"prediction_features.pkl\")\n",
    "# Full predictor set passed to the \"full\" baseline models.\n",
    "full_features = ['n_female', 'n_college', 'p1q1_exploration_div', 'p1q2_tolconflict_div', 'p1q3_optimization_div', \n",
    "                 'age_mean', 'age_std', 'skill_mean', 'skill_std', 'social_mean', 'social_std', \n",
    "                 'mean_nominal_real_gap','mean_zscore_int_solns_per_min', 'mean_zscore_num_postcomplete', 'cogstyle_diversity', \n",
    "                 'p1_efficiency_mean', 'p1_efficiency_std', 'p1_duration_mean', 'p1_duration_std']\n",
    "\n",
    "# Reduced predictor set; every name here also appears in full_features.\n",
    "truncated_features = [\"skill_mean\", \"p1q3_optimization_div\", \"skill_std\", \"social_mean\"]\n",
    "\n",
    "\n",
    "def featimp_baseline_models(df, estimator, feature_cols, outcome_col):\n",
    "    \"\"\"Fit one leave-one-out baseline model per row of `df`.\n",
    "\n",
    "    For every label in `df.index`, an independent copy of `estimator` is fitted\n",
    "    on all rows except the one(s) carrying that label.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Table holding both the predictor columns and the outcome column.\n",
    "    estimator : object\n",
    "        Template exposing a scikit-learn style `fit(X, y)`. It is deep-copied for\n",
    "        each fold, so the object passed in is never fitted or modified.\n",
    "    feature_cols : list of str\n",
    "        Names of the predictor columns.\n",
    "    outcome_col : str\n",
    "        Name of the target column.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list\n",
    "        One entry per label, in `df.index` order: whatever `fit` returned for that\n",
    "        fold (the fitted model itself for scikit-learn estimators).\n",
    "    \"\"\"\n",
    "    baseline_models = []\n",
    "\n",
    "    for loo_index in tqdm(df.index):\n",
    "        # Copy *before* fitting: folds cannot share fitted state (e.g. warm starts)\n",
    "        # and the caller's estimator keeps its own random_state.\n",
    "        model = deepcopy(estimator)\n",
    "        if hasattr(model, \"random_state\"):\n",
    "            # Fold-specific seed, taken from the held-out index label.\n",
    "            model.random_state = loo_index\n",
    "\n",
    "        # Build the training frame once rather than once for X and once for y.\n",
    "        train = df.drop(loo_index)\n",
    "        baseline_models.append(model.fit(X=train[feature_cols].values, y=train[outcome_col]))\n",
    "\n",
    "    return baseline_models\n",
    "\n",
    "\n",
    "def q2_score_models(df, model_list, feature_cols, outcome_col):\n",
    "    \"\"\"Leave-one-out predictive Q2 of pre-fitted models against a mean baseline.\n",
    "\n",
    "    Q2 = 1 - sum((pred_i - y_i)^2) / sum((mean_without_i - y_i)^2)\n",
    "\n",
    "    where `pred_i` is the prediction for row i from the model at position i and\n",
    "    `mean_without_i` is the outcome mean over all rows except row i's label.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Table holding both the predictor columns and the outcome column.\n",
    "    model_list : sequence\n",
    "        Objects exposing `predict(X)`, ordered like `df.index` (the order in which\n",
    "        `featimp_baseline_models` returns them).\n",
    "    feature_cols : list of str\n",
    "        Names of the predictor columns.\n",
    "    outcome_col : str\n",
    "        Name of the target column.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        The Q2 value.\n",
    "    \"\"\"\n",
    "    q2_means = []\n",
    "    q2_preds = []\n",
    "\n",
    "    # Models are stored by position while rows are dropped by label, so track both.\n",
    "    # Using the label as a position is only correct for a default 0..n-1 index.\n",
    "    for position, loo_index in enumerate(df.index):\n",
    "        q2_means.append(df[outcome_col].drop(loo_index).mean())\n",
    "        # Double brackets keep the held-out row 2-D: shape (1, n_features).\n",
    "        held_out_X = df.iloc[[position]][feature_cols].values\n",
    "        q2_preds.append(model_list[position].predict(held_out_X)[0])\n",
    "\n",
    "    observed = df[outcome_col]\n",
    "    q2 = 1 - np.sum((np.array(q2_preds) - observed)**2) / np.sum((np.array(q2_means) - observed)**2)\n",
    "\n",
    "    return q2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [00:00<00:00, 349.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.23628584844435707\n"
     ]
    }
   ],
   "source": [
    "# Linear-regression baseline on the truncated feature set: fit the leave-one-out models, then score Q2.\n",
    "lr_truncated_models = featimp_baseline_models(\n",
    "    df_oos, LinearRegression(), truncated_features, \"p2_cumulative_normscore\"\n",
    ")\n",
    "LR_truncated = q2_score_models(\n",
    "    df_oos, lr_truncated_models, truncated_features, \"p2_cumulative_normscore\"\n",
    ")\n",
    "print(LR_truncated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [00:00<00:00, 514.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3019035778849961\n"
     ]
    }
   ],
   "source": [
    "#SI\n",
    "# Linear-regression baseline on the full feature set: fit the leave-one-out models, then score Q2.\n",
    "lr_full_models = featimp_baseline_models(\n",
    "    df_oos, LinearRegression(), full_features, \"p2_cumulative_normscore\"\n",
    ")\n",
    "LR_full = q2_score_models(\n",
    "    df_oos, lr_full_models, full_features, \"p2_cumulative_normscore\"\n",
    ")\n",
    "print(LR_full)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [00:00<00:00, 433.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.31827168270618056\n"
     ]
    }
   ],
   "source": [
    "#SI\n",
    "# Elastic-net baseline (default hyperparameters) on the full feature set: fit the leave-one-out models, then score Q2.\n",
    "en_full_models = featimp_baseline_models(\n",
    "    df_oos, ElasticNet(), full_features, \"p2_cumulative_normscore\"\n",
    ")\n",
    "EN_full = q2_score_models(\n",
    "    df_oos, en_full_models, full_features, \"p2_cumulative_normscore\"\n",
    ")\n",
    "print(EN_full)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38",
   "language": "python",
   "name": "py38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
